## 4. Main Results
Interpret and summarize the prediction and stability results from the fitted pipeline.
[Insert narrative here.]
Evaluate the fitted pipeline on the held-out test data.
Summarize the test-set prediction and/or interpretability results.
[Insert narrative here.]
Modeling frameworks used below: tidymodels, caret.
# Cross-validation settings shared by every model fit below:
# 5-fold CV, sequential execution, quiet output.
# classProbs is only meaningful for classification, so enable it exactly
# when the response is a factor; `is.factor()` already returns TRUE/FALSE,
# making the original `if (...) TRUE else FALSE` redundant.
trcontrol <- caret::trainControl(
  method = "cv",
  number = 5,                               # 5-fold cross-validation
  classProbs = is.factor(ytrain),           # TRUE only for classification
  summaryFunction = caret::defaultSummary,
  allowParallel = FALSE,
  verboseIter = FALSE
)
# Prediction type passed to predict(): "raw" returns class labels for
# classification or numeric predictions for regression (vs. "prob").
response <- "raw"

# Per-method tuning grids and extra fitting arguments. List names must be
# valid caret method names; list elements are forwarded to caret::train()
# (and from there to the underlying engine) via do.call() below.
# NOTE(review): splitrule = "gini" assumes a classification response —
# ranger needs "variance" for regression; confirm against ytrain.
model_list <- list(
  ranger = list(
    tuneGrid = expand.grid(
      # mtry is a count of candidate predictors per split, so it must be
      # a whole number; seq()/sqrt() produce fractional values, hence
      # round(). unique() drops grid rows that collide after rounding.
      mtry = unique(round(seq(sqrt(ncol(Xtrain)),
                              ncol(Xtrain) / 3,
                              length.out = 3))),
      splitrule = "gini",
      min.node.size = 1
    ),
    importance = "impurity",  # enable impurity-based variable importance
    num.threads = 1           # keep ranger single-threaded (CV is sequential)
  ),
  xgbTree = list(
    tuneGrid = expand.grid(
      nrounds = c(10, 25, 50, 100, 150),
      max_depth = c(3, 6),
      colsample_bytree = 0.33,
      eta = c(0.1, 0.3),
      gamma = 0,
      min_child_weight = 1,
      subsample = 0.6
    ),
    nthread = 1               # keep xgboost single-threaded
  )
)
# Fit, predict, and evaluate each model listed in model_list.
# Per-model results are collected into named lists and combined below.
model_fits <- list()
model_preds <- list()
model_errs <- list()
model_vimps <- list()
for (method in names(model_list)) {
  # Extra arguments (tuning grid, engine options) for this method. An
  # empty list becomes NULL so that c() appends nothing to the call.
  extra_args <- model_list[[method]]
  if (identical(extra_args, list())) {
    extra_args <- NULL
  }
  train_args <- c(
    list(x = as.data.frame(Xtrain),
         y = ytrain,
         trControl = trcontrol,
         method = method),
    extra_args
  )
  fit <- do.call(caret::train, args = train_args)
  model_fits[[method]] <- fit
  # Test-set predictions, errors, and variable importances for this fit.
  model_preds[[method]] <- predict(fit, as.data.frame(Xtest),
                                   type = response)
  model_errs[[method]] <- caret::postResample(pred = model_preds[[method]],
                                              obs = ytest)
  model_vimps[[method]] <- caret::varImp(fit)
}
# Stack the per-model results into long-format data frames, adding a
# "model" column (taken from the list names) to identify the source fit.
model_preds <- dplyr::bind_rows(model_preds, .id = "model")
model_errs <- dplyr::bind_rows(model_errs, .id = "model")
# caret::varImp() returns an object whose "importance" element is a data
# frame with variables as row names; promote those into a "variable"
# column before stacking. map() + bind_rows(.id) is equivalent to the
# original map_dfr() call.
model_vimps <- model_vimps %>%
  purrr::map(~tibble::rownames_to_column(.x[["importance"]], "variable")) %>%
  dplyr::bind_rows(.id = "model")